import os
import json
import argparse
from datetime import datetime
from tqdm import tqdm
from Heuristic_Retrieval import Heuristic_Retrieval_Benchmark
from Heuristic_RAG_Retrieval import Heuristic_RAG_Retrieval_Benchmark
from Keyword_Retrieval import Keyword_Retrieval_Benchmark

# Registry mapping each --benchmark choice to the class that implements it.
# Insertion order is Heuristic, Keyword, Heuristic_RAG.
OPTIONS = dict(
    Heuristic=Heuristic_Retrieval_Benchmark,
    Keyword=Keyword_Retrieval_Benchmark,
    Heuristic_RAG=Heuristic_RAG_Retrieval_Benchmark,
)

# Default locations used when the matching command-line option is omitted.
input_raw_data_path = '/mnt/midnight/steven_zhang/LLM_assisted_compilation/Compilation_Benchmark/data/repo_list_76.json'
output_benchmark_path = '/mnt/midnight/steven_zhang/LLM_assisted_compilation/Compilation_Benchmark/data/retrieval_benchmark_76.json'
cloned_repos_dir = '/mnt/midnight/steven_zhang/LLM_assisted_compilation/cloned_repos'

def main(argv=None):
    """Parse command-line options and instantiate the selected retrieval benchmark.

    The benchmark class is looked up in ``OPTIONS`` by the ``--benchmark``
    value and constructed with the parsed paths, a timestamped results-file
    path, ``refine_times=3`` and ``multi_processing=False``.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behaviour.
    """
    parser = argparse.ArgumentParser()
    # Choices come straight from the registry so that the CLI and OPTIONS
    # cannot drift apart when a benchmark is added or removed.
    parser.add_argument(
        '--benchmark',
        type=str,
        default='Heuristic',
        choices=list(OPTIONS),
        help="The retrieval benchmark to use.",
    )
    parser.add_argument('--input_raw_data_path', type=str, default=input_raw_data_path)
    parser.add_argument('--output_benchmark_path', type=str, default=output_benchmark_path)
    parser.add_argument('--cloned_repos_dir', type=str, default=cloned_repos_dir)
    parser.add_argument(
        "--pre_computed_retrieval_results_path",
        type=str,
        default=None,
        help="Path to pre-computed retrieval results.",
    )

    args = parser.parse_args(argv)

    # Each run writes to its own file: the name embeds the lower-cased
    # benchmark name and the local start time (second resolution).
    datetime_str = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    benchmark_name = args.benchmark.lower()
    output_retrieval_results_file_path = f'/mnt/midnight/steven_zhang/LLM_assisted_compilation/Compilation_Benchmark/experiment_results/retrieval_results/{benchmark_name}/{benchmark_name}_retrieval_results_{datetime_str}.json'

    # NOTE(review): the instance is not used after construction, so the
    # benchmark presumably does its work inside __init__ -- confirm against
    # the benchmark classes.
    OPTIONS[args.benchmark](
        input_raw_data_path=args.input_raw_data_path,
        output_benchmark_path=args.output_benchmark_path,
        cloned_repos_dir=args.cloned_repos_dir,
        output_retrieval_results_file_path=output_retrieval_results_file_path,
        pre_computed_benchmark_file_path=None,
        pre_computed_retrieval_results_path=args.pre_computed_retrieval_results_path,
        refine_times=3,
        multi_processing=False,
    )
    
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()